import sys
import argparse
import json
import os
import mxnet as mx
import caffe
from prototxt_basic import *

def json2prototxt(mx_json, cf_prototxt):
    """Translate an MXNet symbol JSON file into a Caffe prototxt.

    Each node of ``jdata['nodes']`` that is kept becomes a layer description
    dict which is handed to ``write_node`` (from ``prototxt_basic``) together
    with the open ``cf_prototxt`` file.  Not emitted as layers: parameter
    (``op == 'null'``) nodes other than ``data``, and the ops ``_copy``,
    ``_minus_scalar``, ``_mul_scalar`` and ``Dropout``.

    The dict passed to ``write_node`` is the JSON node itself (mutated in
    place), extended with:

    * ``top``    -- the node's own name;
    * ``bottom`` -- names of its input layers.  A ``_mul_scalar`` input is
      replaced by ``data`` and a ``Dropout`` input by the Dropout's own first
      input, so those skipped nodes are bypassed;
    * ``params`` -- names of its parameter inputs (``null`` nodes other than
      the graph input ``data``);
    * ``share``  -- set to True when a parameter name does not start with the
      node's name (presumably a weight shared with another layer).

    Args:
        mx_json: path of the MXNet ``*-symbol.json`` file.
        cf_prototxt: path of the prototxt file to write (overwritten).
    """
    # Ops that are folded away rather than written as Caffe layers.
    skipped_ops = ('_copy', '_minus_scalar', '_mul_scalar', 'Dropout')

    with open(mx_json) as json_file:
        jdata = json.load(json_file)
    nodes = jdata['nodes']

    with open(cf_prototxt, "w") as prototxt_file:
        for i_node, node_i in enumerate(nodes):
            node_op = str(node_i['op'])
            node_name = str(node_i['name'])
            if node_op == 'null' and node_name != 'data':
                continue
            if node_op in skipped_ops:
                continue

            print('{}, \top:{}, name:{} -> {}'.format(i_node,node_i['op'].ljust(20),
                                                node_i['name'].ljust(30),
                                                node_i['name']).ljust(20))
            info = node_i
            info['top'] = info['name']
            info['bottom'] = []
            info['params'] = []
            for input_idx_i in node_i['inputs']:
                input_i = nodes[input_idx_i[0]]
                input_op = str(input_i['op'])
                input_name = str(input_i['name'])

                if input_op == '_mul_scalar':
                    # the skipped input-normalisation chain ends here: feed from 'data'
                    info['bottom'].append('data')
                elif input_op == 'Dropout':
                    # bypass the skipped Dropout node and connect to what feeds it
                    dropout_src = nodes[input_i['inputs'][0][0]]
                    info['bottom'].append(str(dropout_src['name']))
                elif input_op != 'null' or input_name == 'data':
                    info['bottom'].append(input_name)

                # 'data' is the graph input, not a learned parameter: it must
                # not be listed as a param nor flagged as a shared weight.
                if input_op == 'null' and input_name != 'data':
                    info['params'].append(input_name)
                    if not input_name.startswith(node_name):
                        print('           use shared weight -> %s'% input_name)
                        info['share'] = True
            write_node(prototxt_file, info)

def _copy_blob(blob, ndarray, key_i):
    """Copy the MXNet NDArray *ndarray* element-wise into the Caffe *blob*.

    numpy's ``.flat`` assignment does not require the sizes to agree (a
    shorter source is repeated, a longer one truncated), so a mismatch is
    reported here instead of silently producing wrong weights.  The copy is
    still performed to keep the previous behaviour.
    """
    values = ndarray.asnumpy()
    if values.size != blob.data.size:
        print("Warning!  size mismatch mxnet:{} {} -> caffe {}".format(
            key_i, values.shape, blob.data.shape))
    blob.data.flat = values.flat


def mxnet2caffe(mx_model, mx_epoch, cf_prototxt, cf_model):
    """Copy trained MXNet parameters into a Caffe net and save it.

    Loads checkpoint ``mx_model``/``mx_epoch`` with ``mx.model.load_checkpoint``
    and the network definition ``cf_prototxt`` with ``caffe.Net``, copies every
    argument and auxiliary parameter into the matching Caffe blob, then writes
    the weights to ``cf_model``.  Keys are processed in sorted order and
    mapped by name suffix:

    * ``<layer>_weight`` -> ``net.params[<layer>][0]``; any weight key
      containing ``fc1`` is redirected to the Caffe layer ``pre_fc1``
    * ``<layer>_bias``   -> ``net.params[<layer>][1]``
    * ``<layer>_gamma``  -> ``net.params[<layer>_scale][0]``, or, for names
      containing ``relu`` (PReLU), ``net.params[<layer>][0]``
    * ``<layer>_beta``   -> ``net.params[<layer>_scale][1]``
    * ``<layer>_moving_mean`` / ``_running_mean`` -> ``net.params[<layer>][0]``
    * ``<layer>_moving_var``  / ``_running_var``  -> ``net.params[<layer>][1]``
      (both statistics also set ``net.params[<layer>][2]`` to 1)

    A key whose Caffe layer is missing is reported and skipped; a key with an
    unrecognised suffix aborts the program via ``sys.exit``.

    Args:
        mx_model: checkpoint prefix passed to ``mx.model.load_checkpoint``.
        mx_epoch: checkpoint epoch number.
        cf_prototxt: Caffe network definition to load.
        cf_model: output ``.caffemodel`` path.
    """
    print("-------input mx_model:%s, mx_epoch:%d, cf_prototxt:%s, cf_model:%s" %(mx_model, mx_epoch, cf_prototxt, cf_model))
    # ------------------------------------------
    # Load
    _, arg_params, aux_params = mx.model.load_checkpoint(mx_model, mx_epoch)
    print("-------load mxnet model successful")
    net = caffe.Net(cf_prototxt, caffe.TEST)
    print("-------load caffe prototxt success")

    # ------------------------------------------
    # Convert
    all_keys = sorted(list(arg_params.keys()) + list(aux_params.keys()))
    print('----------------------------------\n')
    print('ALL KEYS IN MXNET:')
    print(all_keys)
    print('%d KEYS' %len(all_keys))

    print('----------------------------------\n')
    print('VALID KEYS:')
    for i_key, key_i in enumerate(all_keys):
        if key_i == 'data':
            # graph input placeholder, not a learned parameter: nothing to copy
            continue
        try:
            if '_weight' in key_i:
                key_caffe = key_i.replace('_weight', '')
                if 'fc1' in key_i:
                    # the fc1 weight is loaded into the Caffe layer named 'pre_fc1'
                    print(key_i)
                    print(arg_params[key_i].shape)
                    key_caffe = 'pre_fc1'
                    print(net.params[key_caffe][0].data.shape)
                _copy_blob(net.params[key_caffe][0], arg_params[key_i], key_i)
            elif '_bias' in key_i:
                key_caffe = key_i.replace('_bias', '')
                _copy_blob(net.params[key_caffe][1], arg_params[key_i], key_i)
            elif '_gamma' in key_i and 'relu' not in key_i:
                # presumably the Scale layer paired with a Caffe BatchNorm
                key_caffe = key_i.replace('_gamma', '_scale')
                _copy_blob(net.params[key_caffe][0], arg_params[key_i], key_i)
            elif '_gamma' in key_i:
                # PReLU slope: the Caffe layer holds exactly one blob
                key_caffe = key_i.replace('_gamma', '')
                assert (len(net.params[key_caffe]) == 1)
                _copy_blob(net.params[key_caffe][0], arg_params[key_i], key_i)
            elif '_beta' in key_i:
                key_caffe = key_i.replace('_beta', '_scale')
                _copy_blob(net.params[key_caffe][1], arg_params[key_i], key_i)
            # BatchNorm statistics.  Blob 2 of a Caffe BatchNorm is the
            # moving-average scale factor the stored mean/var are divided by;
            # 1 makes Caffe use the copied statistics as-is.
            elif '_moving_mean' in key_i:
                key_caffe = key_i.replace('_moving_mean', '')
                _copy_blob(net.params[key_caffe][0], aux_params[key_i], key_i)
                net.params[key_caffe][2].data[...] = 1
            elif '_moving_var' in key_i:
                key_caffe = key_i.replace('_moving_var', '')
                _copy_blob(net.params[key_caffe][1], aux_params[key_i], key_i)
                net.params[key_caffe][2].data[...] = 1
            elif '_running_mean' in key_i:
                key_caffe = key_i.replace('_running_mean', '')
                _copy_blob(net.params[key_caffe][0], aux_params[key_i], key_i)
                net.params[key_caffe][2].data[...] = 1
            elif '_running_var' in key_i:
                key_caffe = key_i.replace('_running_var', '')
                _copy_blob(net.params[key_caffe][1], aux_params[key_i], key_i)
                net.params[key_caffe][2].data[...] = 1
            else:
                sys.exit("Warning!  Unknown mxnet:{}".format(key_i))
            print("% 3d | %s -> %s, initialized."
                  %(i_key, key_i.ljust(40), key_caffe.ljust(30)))
        except KeyError:
            # best-effort: the Caffe net (or param dict) lacks this key; report and go on
            print("\nError!  key error mxnet:{}".format(key_i))

    # ------------------------------------------
    # Finish
    net.save(cf_model)
    print("Mxnet to caffe Finished!\n")

def make_parser():
    """Build the command-line parser for the converter.

    ``--config`` names the ``key=value`` conversion config file.  It is
    optional and falls back to ``./net_inst.txt``; the previous
    ``required=True`` made that default unreachable.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default="./net_inst.txt", type=str, help='config file for inference')
    return parser

if __name__ == "__main__":
    myparser = make_parser()
    args = myparser.parse_args()
    config_file = args.config
    file_config = open(args.config).read()
    lines_config = file_config.split("\n")
    for i in range(0,len(lines_config)):
        if "mxnet_json" in lines_config[i]:
            mxnet_json=lines_config[i].split("=")[-1]
        elif "mxnet_checkpoint" in lines_config[i]:
            mxnet_checkpoint=lines_config[i].split("=")[-1]
        elif "mxnet_epoch" in lines_config[i]:
            mxnet_epoch=lines_config[i].split("=")[-1]
        elif "inst_file" in lines_config[i]:
            caffe_prototxt = lines_config[i].split("=")[-1] + ".prototxt"
            caffe_model = lines_config[i].split("=")[-1] + ".caffemodel"

    json2prototxt(mxnet_json, caffe_prototxt)
    mxnet2caffe(mxnet_checkpoint, int(mxnet_epoch), caffe_prototxt, caffe_model)
